-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][memref] Add runtime verification for memref.dim
#130410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][memref] Add runtime verification for memref.dim
#130410
Conversation
|
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd runtime verification for Also simplify the pass pipeline for all memref runtime verification checks. Full diff: https://github.com/llvm/llvm-project/pull/130410.diff 7 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index f93ae0a7a298f..f825d7d9d42c2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -23,6 +23,18 @@ using namespace mlir;
namespace mlir {
namespace memref {
namespace {
+/// Generate a runtime check for lb <= value < ub.
+Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
+ Value lb, Value ub) {
+ Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, value, lb);
+ Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, value, ub);
+ Value inBounds =
+ builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
+ return inBounds;
+}
+
struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
@@ -128,6 +140,21 @@ struct CastOpInterface
}
};
+struct DimOpInterface
+ : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
+ DimOp> {
+ void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ auto dimOp = cast<DimOp>(op);
+ Value rank = builder.create<RankOp>(loc, dimOp.getSource());
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ builder.create<cf::AssertOp>(
+ loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "index is out of bounds"));
+ }
+};
+
/// Verifies that the indices on load/store ops are in-bounds of the memref's
/// index space: 0 <= index#i < dim#i
template <typename LoadStoreOp>
@@ -148,19 +175,12 @@ struct LoadStoreOpInterface
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value assertCond;
for (auto i : llvm::seq<int64_t>(0, rank)) {
- auto index = indices[i];
-
- auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
-
- auto geLow = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, index, zero);
- auto ltHigh = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, index, dimOp);
- auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
-
+ Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
+ Value inBounds =
+ generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
assertCond =
- i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
- : andOp;
+ i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
+ : inBounds;
}
builder.create<cf::AssertOp>(
loc, assertCond,
@@ -335,6 +355,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
+ DimOp::attachInterface<DimOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index 62db9ce1316ae..a40bc2b3272fc 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -28,11 +28,19 @@ struct GenerateRuntimeVerificationPass
} // namespace
void GenerateRuntimeVerificationPass::runOnOperation() {
+ // The implementation of the RuntimeVerifiableOpInterface may create ops that
+ // can be verified. We don't want to generate verification for IR that
+ // performs verification, so gather all runtime-verifiable ops first.
+ SmallVector<RuntimeVerifiableOpInterface> ops;
getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
- OpBuilder builder(getOperation()->getContext());
+ ops.push_back(verifiableOp);
+ });
+
+ OpBuilder builder(getOperation()->getContext());
+ for (RuntimeVerifiableOpInterface verifiableOp : ops) {
builder.setInsertionPoint(verifiableOp);
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
- });
+ };
}
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
diff --git a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
index b101a875154ff..8b6308e9c1939 100644
--- a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
@@ -1,8 +1,7 @@
-// RUN: mlir-opt %s -generate-runtime-verification -finalize-memref-to-llvm \
+// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -expand-strided-metadata \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
new file mode 100644
index 0000000000000..2e3f271743c93
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @main() {
+ %c4 = arith.constant 4 : index
+ %alloca = memref.alloca() : memref<1xf32>
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.dim"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> index
+ // CHECK-NEXT: ^ index is out of bounds
+ // CHECK-NEXT: Location: loc({{.*}})
+ %dim = memref.dim %alloca, %c4 : memref<1xf32>
+
+ return
+}
diff --git a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
index d6c5d6da0041e..b87e5bdf0970c 100644
--- a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
@@ -1,10 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -expand-strided-metadata \
-// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
index 9fea48bdfc07d..601a53f4b5cd9 100644
--- a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
@@ -1,10 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -lower-affine \
-// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 66474e9c4ae37..3cac37a082c30 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -1,11 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -test-cf-assert \
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
-// RUN: -finalize-memref-to-llvm \
-// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
|
|
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesAdd runtime verification for Also simplify the pass pipeline for all memref runtime verification checks. Full diff: https://github.com/llvm/llvm-project/pull/130410.diff 7 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index f93ae0a7a298f..f825d7d9d42c2 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -23,6 +23,18 @@ using namespace mlir;
namespace mlir {
namespace memref {
namespace {
+/// Generate a runtime check for lb <= value < ub.
+Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value,
+ Value lb, Value ub) {
+ Value inBounds1 = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, value, lb);
+ Value inBounds2 = builder.createOrFold<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, value, ub);
+ Value inBounds =
+ builder.createOrFold<arith::AndIOp>(loc, inBounds1, inBounds2);
+ return inBounds;
+}
+
struct CastOpInterface
: public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
CastOp> {
@@ -128,6 +140,21 @@ struct CastOpInterface
}
};
+struct DimOpInterface
+ : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface,
+ DimOp> {
+ void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ auto dimOp = cast<DimOp>(op);
+ Value rank = builder.create<RankOp>(loc, dimOp.getSource());
+ Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ builder.create<cf::AssertOp>(
+ loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank),
+ RuntimeVerifiableOpInterface::generateErrorMessage(
+ op, "index is out of bounds"));
+ }
+};
+
/// Verifies that the indices on load/store ops are in-bounds of the memref's
/// index space: 0 <= index#i < dim#i
template <typename LoadStoreOp>
@@ -148,19 +175,12 @@ struct LoadStoreOpInterface
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value assertCond;
for (auto i : llvm::seq<int64_t>(0, rank)) {
- auto index = indices[i];
-
- auto dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
-
- auto geLow = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, index, zero);
- auto ltHigh = builder.createOrFold<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, index, dimOp);
- auto andOp = builder.createOrFold<arith::AndIOp>(loc, geLow, ltHigh);
-
+ Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i);
+ Value inBounds =
+ generateInBoundsCheck(builder, loc, indices[i], zero, dimOp);
assertCond =
- i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, andOp)
- : andOp;
+ i > 0 ? builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds)
+ : inBounds;
}
builder.create<cf::AssertOp>(
loc, assertCond,
@@ -335,6 +355,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
CastOp::attachInterface<CastOpInterface>(*ctx);
+ DimOp::attachInterface<DimOpInterface>(*ctx);
ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx);
ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx);
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index 62db9ce1316ae..a40bc2b3272fc 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -28,11 +28,19 @@ struct GenerateRuntimeVerificationPass
} // namespace
void GenerateRuntimeVerificationPass::runOnOperation() {
+ // The implementation of the RuntimeVerifiableOpInterface may create ops that
+ // can be verified. We don't want to generate verification for IR that
+ // performs verification, so gather all runtime-verifiable ops first.
+ SmallVector<RuntimeVerifiableOpInterface> ops;
getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
- OpBuilder builder(getOperation()->getContext());
+ ops.push_back(verifiableOp);
+ });
+
+ OpBuilder builder(getOperation()->getContext());
+ for (RuntimeVerifiableOpInterface verifiableOp : ops) {
builder.setInsertionPoint(verifiableOp);
verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
- });
+ };
}
std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
diff --git a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
index b101a875154ff..8b6308e9c1939 100644
--- a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
@@ -1,8 +1,7 @@
-// RUN: mlir-opt %s -generate-runtime-verification -finalize-memref-to-llvm \
+// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -expand-strided-metadata \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
new file mode 100644
index 0000000000000..2e3f271743c93
--- /dev/null
+++ b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -expand-strided-metadata \
+// RUN: -test-cf-assert \
+// RUN: -convert-to-llvm | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @main() {
+ %c4 = arith.constant 4 : index
+ %alloca = memref.alloca() : memref<1xf32>
+
+ // CHECK: ERROR: Runtime op verification failed
+ // CHECK-NEXT: "memref.dim"(%{{.*}}, %{{.*}}) : (memref<1xf32>, index) -> index
+ // CHECK-NEXT: ^ index is out of bounds
+ // CHECK-NEXT: Location: loc({{.*}})
+ %dim = memref.dim %alloca, %c4 : memref<1xf32>
+
+ return
+}
diff --git a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
index d6c5d6da0041e..b87e5bdf0970c 100644
--- a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
@@ -1,10 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -expand-strided-metadata \
-// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
index 9fea48bdfc07d..601a53f4b5cd9 100644
--- a/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/reinterpret-cast-runtime-verification.mlir
@@ -1,10 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
-// RUN: -lower-affine \
-// RUN: -finalize-memref-to-llvm \
// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -expand-strided-metadata \
+// RUN: -lower-affine \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 66474e9c4ae37..3cac37a082c30 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -1,11 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN: -test-cf-assert \
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
-// RUN: -finalize-memref-to-llvm \
-// RUN: -test-cf-assert \
-// RUN: -convert-func-to-llvm \
-// RUN: -convert-arith-to-llvm \
-// RUN: -reconcile-unrealized-casts | \
+// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
ryanpholt
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
6883972 to
a44867a
Compare
Add runtime verification for
memref.dim: check that the index is in bounds.Also simplify the pass pipeline for all memref runtime verification checks.